{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RLHF Implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import packages and libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "tqdm.pandas()\n",
    "from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead,RewardTrainer\n",
    "from trl.core import LengthSampler\n",
    "import random\n",
    "from datasets import Dataset, load_dataset\n",
    "from transformers import AutoModelForSequenceClassification,AutoTokenizer,TrainingArguments,pipeline\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conifguring the model to be finetuned using RL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = PPOConfig(\n",
    "    model_name=\"gpt2\",\n",
    "    learning_rate=1.41e-5)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataset(config, input_min_text_length=2, input_max_text_length=200):\n",
    "    tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "    df = pd.read_csv(\"./../input/feedback.csv\")\n",
    "    ds = Dataset.from_pandas(df)\n",
    "    ds = ds.rename_columns({\"question\": \"review\"})\n",
    "    input_size = LengthSampler(input_min_text_length, input_max_text_length)\n",
    "    def tokenize(sample):\n",
    "        sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n",
    "        sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n",
    "        return sample\n",
    "    ds = ds.map(tokenize, batched=False)\n",
    "    ds.set_format(type=\"torch\")\n",
    "    return ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/268 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset = build_dataset(config)\n",
    "def collator(data):\n",
    "    return dict((key, [d[key] for d in data]) for key in data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'review': ['Can Maximo Visual Inspection run on prem?\\u200b\\u200b\\u200b',\n",
       "  'What is watson knowledge catalog?',\n",
       "  'What is watson knowledge catalog?',\n",
       "  'Can Instana use OpenTelemetry trace data?\\u200b\\u200b\\u200b\\u200b\\u200b',\n",
       "  'format it on a table',\n",
       "  'what are the differences between Watson Orchestrate and other IBM automation tools?',\n",
       "  'what are the differences between Watson Orchestrate and other IBM automation tools?',\n",
       "  'What Watson Orchestrate is for?',\n",
       "  'What Watson Orchestrate is for?',\n",
       "  'what are the differences between Watson Orchestrate and other IBM automation tools?'],\n",
       " 'answer': [' Answer: Yes, Maximo Visual Inspection can be run on premise. ',\n",
       "  ' Answer: The IBM Watson Knowledge Catalog is a data catalog that serves as a single version of the truth for different users including data engineers, business analysts, data analysts, data scientists, and data citizens. Users can gain access to data they can trust, govern, curate, share and manage within an organization. Beyond that, you can add data policies and rules around your data to ensure your information doesnt get into the wrong hands and is compliant. Take control of your data by protecting your sensitive information and tracing the lineage of your data. You should always be certain of where the data came from and how it is being used in your organization. The data governance, data quality, and policy management capabilities within the IBM Watson Knowledge Catalog will help make your data business-ready. The journey to AI starts here, with clean and trusted data. Is your data business-ready? Keep reading to find tips on how to leverage the IBM Watson Knowledge Catalog to get trusted data. One of the',\n",
       "  ' Answer: The IBM Watson Knowledge Catalog is a data catalog that serves as a single version of the truth for different users including data engineers, business analysts, data analysts, data scientists, and data citizens. Users can gain access to data they can trust, govern, curate, share and manage within an organization. Beyond that, you can add data policies and rules around your data to ensure your information doesnt get into the wrong hands and is compliant. Take control of your data by protecting your sensitive information and tracing the lineage of your data. You should always be certain of where the data came from and how it is being used in your organization. The data governance, data quality, and policy management capabilities within the IBM Watson Knowledge Catalog will help make your data business-ready. The journey to AI starts here, with clean and trusted data. Is your data business-ready? Keep reading to find tips on how to leverage the IBM Watson Knowledge Catalog to get trusted data. One of the',\n",
       "  ' Answer: Yes, Instana can ingest OpenTelemetry trace data. Instana supports OpenTelemetry trace data from any OpenTelemetry-compatible application. Instana can ingest trace data from any application that is instrumented with OpenTelemetry. OpenTelemetry is an open-source standard for instrumenting, collecting, and aggregating telemetry data from applications. OpenTelemetry is a standard for instrumenting, collecting, and aggregating telemetry data from applications. OpenTelemetry is a standard for instrumenting, collecting, and aggregating telemetry data from applications. OpenTelemetry is an open-source standard for instrumenting, collecting, and aggregating telemetry data from applications. OpenTelemetry is a standard for instrumenting, collecting, and aggregating telemetry data from applications. OpenTelemetry is an open-source standard for instrumenting, collecting, and aggregating telemetry data from applications. OpenTelemetry is an open-source standard for instrumenting, collecting, and aggregating',\n",
       "  'with the following columns. The first column is the id, the second column is the name, the third column is the phone number, the fourth column is the lazy boolean value, and the fifth column is the married boolean value. Answer: The following SQL statement creates a table with the required columns and inserts the sample data from Listing 3 into the table. The SQL statement uses the JSON_TABLE function to extract the required information from the JSON object. The JSON_TABLE function is used in the FROM clause of the SQL statement. The SQL statement uses a JSON path expression to specify the columns of the table. The JSON path expression is a sequence of SQL/JSON path expressions separated by the comma operator. The SQL/JSON path expression is a sequence of SQL/JSON path expressions separated by the comma operator. The SQL/JSON path expression is a sequence of SQL/JSON path expressions separated by the comma operator. The SQL/JSON path expression is a sequence of SQL/JSON path expressions separated by the',\n",
       "  ' Answer: Watson Orchestrate is a cloud-native, open-source, and open API orchestration platform that accelerates the development and deployment of cloud-native applications and workflows. It is a modern orchestration platform that is built on Kubernetes and is designed to manage and orchestrate complex workflows across hybrid cloud environments. It is a cloud-native orchestration platform that is designed to manage and orchestrate complex workflows across hybrid cloud environments. IBM Integration, IBM Business Automation, and IBM IT Automation products are available on AWS. IBM Integration products on AWS: IBM Integration Bus, IBM Integration Bus for z/OS, IBM Integration Bus for z/OS on Linux, IBM Integration Bus for z/OS on Linux on Red Hat OpenShift, IBM Integration Bus for z/OS on Linux on Red Hat OpenShift on IBM Cloud, IBM Integration Bus for z/OS on Linux on Red Hat OpenShift on IBM Cloud on Red Hat OpenShift, IBM Integration Bus',\n",
       "  ' Answer: Watson Orchestrate is a cloud-native, open-source orchestration platform that enables you to automate and manage complex workflows across hybrid cloud environments. It is a fully managed service that can be deployed in AWS, Azure, and GCP. It is built on Kubernetes and can be used with any cloud-native application. It can be used to orchestrate workflows across multiple cloud-native applications, including IBM Cloud Pak for Integration, IBM Cloud Pak for Data, IBM Cloud Pak for Applications, and IBM Cloud Pak for Multicloud Management. It can also be used to orchestrate workflows across multiple cloud-native applications, including IBM Cloud Pak for Integration, IBM Cloud Pak for Data, IBM Cloud Pak for Applications, and IBM Cloud Pak for Multicloud Management. Watson Orchestrate is a cloud-native, open-source orchestration platform that enables you to automate and manage complex workflows across hybrid cloud environments. It is a fully managed service that can be deployed in',\n",
       "  ' : What is a data fabric? : What is a data lake? ',\n",
       "  ' Answer: Watson Orchestrate is a cloud-based service that helps you automate your data science and machine learning workflows. ',\n",
       "  ' Answer: Watson Orchestrate is a service that provides a unified orchestration platform for automating the deployment, management, and governance of cloud-native applications. It is a cloud-native, open-source, multi-cloud orchestration platform that can be used to automate the deployment, management, and governance of cloud-native applications. It is a cloud-native, open-source, multi-cloud orchestration platform that can be used to automate the deployment, management, and governance of cloud-native applications. It is built on Kubernetes and is designed to work with any cloud provider. It is built on Kubernetes and is designed to work with any cloud provider. It is a cloud-native, open-source, multi-cloud orchestration platform that can be used to automate the deployment, management, and governance of cloud-native applications. It is a cloud-native, open-source, multi-cloud orchestration platform that can be used to automate the deployment, management, and governance of cloud-native applications. It is'],\n",
       " 'feedback': tensor([2., 4., 3., 4., 1., 2., 2., 1., 3., 3.]),\n",
       " 'input_ids': [tensor([ 6090,  5436, 25147, 15612, 47115,  1057,   319,  4199,    30, 39009,\n",
       "           9525]),\n",
       "  tensor([ 2061,   318,   266, 13506,  3725, 18388,    30]),\n",
       "  tensor([ 2061,   318,   266, 13506,  3725, 18388,    30]),\n",
       "  tensor([ 6090,  2262,  2271,   779,  4946, 31709, 41935, 12854,  1366,    30,\n",
       "          39009, 39009,  9525]),\n",
       "  tensor([18982,   340,   319,   257,  3084]),\n",
       "  tensor([10919,   389,   262,  5400,  1022, 14959, 30369, 23104,   290,   584,\n",
       "          19764, 22771,  4899,    30]),\n",
       "  tensor([10919,   389,   262,  5400,  1022, 14959, 30369, 23104,   290,   584,\n",
       "          19764, 22771,  4899,    30]),\n",
       "  tensor([ 2061, 14959, 30369, 23104,   318,   329,    30]),\n",
       "  tensor([ 2061, 14959, 30369, 23104,   318,   329,    30]),\n",
       "  tensor([10919,   389,   262,  5400,  1022, 14959, 30369, 23104,   290,   584,\n",
       "          19764, 22771,  4899,    30])],\n",
       " 'query': ['Can Maximo Visual Inspection run on prem?\\u200b\\u200b\\u200b',\n",
       "  'What is watson knowledge catalog?',\n",
       "  'What is watson knowledge catalog?',\n",
       "  'Can Instana use OpenTelemetry trace data?\\u200b\\u200b\\u200b\\u200b\\u200b',\n",
       "  'format it on a table',\n",
       "  'what are the differences between Watson Orchestrate and other IBM automation tools?',\n",
       "  'what are the differences between Watson Orchestrate and other IBM automation tools?',\n",
       "  'What Watson Orchestrate is for?',\n",
       "  'What Watson Orchestrate is for?',\n",
       "  'what are the differences between Watson Orchestrate and other IBM automation tools?']}"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuring PPO Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n",
    "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n",
    "tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = ppo_trainer.accelerator.device\n",
    "if ppo_trainer.accelerator.num_processes == 1:\n",
    "    device = 0 if torch.cuda.is_available() else \"cpu\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading Trained Reward Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "rm_model_trained = AutoModelForSequenceClassification.from_pretrained(\"./../output/reward_model\")\n",
    "rm_tokenizer_trained = AutoTokenizer.from_pretrained(\"./../output/reward_model\")\n",
    "\n",
    "if rm_tokenizer_trained.pad_token is None:\n",
    "    rm_tokenizer_trained.pad_token = rm_tokenizer_trained.eos_token\n",
    "    rm_model_trained.config.pad_token_id = rm_model_trained.config.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_kwargs = {\"min_length\": -1, \"top_k\": 0.0, \"top_p\": 1.0, \"do_sample\": True, \"pad_token_id\": tokenizer.eos_token_id}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tuning using PPO "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]/tmp/ipykernel_1007370/2770970298.py:29: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  rewards = [torch.tensor(i) for i in outputs.logits]\n",
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "/data/rlhf/miniconda3/envs/ak_rlhf/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:1067: UserWarning: The average ratio of batch (11.81) exceeds threshold 10.00. Skipping batch.\n",
      "  warnings.warn(\n",
      "/data/rlhf/miniconda3/envs/ak_rlhf/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:1067: UserWarning: The average ratio of batch (11.86) exceeds threshold 10.00. Skipping batch.\n",
      "  warnings.warn(\n",
      "/data/rlhf/miniconda3/envs/ak_rlhf/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:1067: UserWarning: The average ratio of batch (12.30) exceeds threshold 10.00. Skipping batch.\n",
      "  warnings.warn(\n",
      "/data/rlhf/miniconda3/envs/ak_rlhf/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:1067: UserWarning: The average ratio of batch (57.29) exceeds threshold 10.00. Skipping batch.\n",
      "  warnings.warn(\n",
      "1it [01:30, 90.60s/it]\n"
     ]
    }
   ],
   "source": [
    "output_min_length = 4\n",
    "output_max_length = 16\n",
    "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n",
    "\n",
    "\n",
    "generation_kwargs = {\n",
    "    \"min_length\": -1,\n",
    "    \"top_k\": 0.0,\n",
    "    \"top_p\": 1.0,\n",
    "    \"do_sample\": True,\n",
    "    \"pad_token_id\": tokenizer.eos_token_id,\n",
    "}\n",
    "\n",
    "\n",
    "for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):\n",
    "    query_tensors = batch[\"input_ids\"]\n",
    "\n",
    "    response_tensors = []\n",
    "    for query in query_tensors:\n",
    "        gen_len = output_length_sampler()\n",
    "        generation_kwargs[\"max_new_tokens\"] = gen_len\n",
    "        response = ppo_trainer.generate(query, **generation_kwargs)\n",
    "        response_tensors.append(response.squeeze()[-gen_len:])\n",
    "    batch[\"response\"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]\n",
    "\n",
    "    text = [q + r for q, r in zip(batch[\"query\"], batch[\"response\"])]\n",
    "    encoding = rm_tokenizer_trained(text, return_tensors=\"pt\",padding='max_length',truncation=True)\n",
    "    outputs = rm_model_trained(**encoding)\n",
    "    rewards = [torch.tensor(i) for i in outputs.logits]\n",
    "\n",
    "    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)\n",
    "    ppo_trainer.log_stats(stats, batch, rewards)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ak_rlhf",
   "language": "python",
   "name": "ak_rlhf"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
